# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import OrderedDict

from torch.utils import data

import examples.torch.semantic_segmentation.utils.data as data_utils
import nncf


class CamVid(data.Dataset):
    """CamVid dataset loader where the dataset is arranged as in
    https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.


    Keyword arguments:
    - root (``string``): Root directory path.
    - image_set (``string``): The type of dataset: 'train' for training set,
    'val' for validation set, and 'test' for test set. Matched
    case-insensitively.
    - transforms (``callable``, optional): A function/transform that takes in
    an (image, label) pair of PIL images and returns a transformed version.
    Default: None.
    - loader (``callable``, optional): A function to load an (image, label)
    pair given their paths. By default ``data_utils.pil_loader`` is used.

    Raises:
    - nncf.ValidationError: if ``image_set`` is not one of the supported modes.
    """

    # Training dataset root folders
    train_folder = "train"
    train_lbl_folder = "trainannot"

    # Validation dataset root folders
    val_folder = "val"
    val_lbl_folder = "valannot"

    # Test dataset root folders
    test_folder = "test"
    test_lbl_folder = "testannot"

    # Images extension
    img_extension = ".png"

    # Default encoding for pixel value, class name, and class color
    color_encoding = OrderedDict(
        [
            ("sky", (128, 128, 128)),
            ("building", (128, 0, 0)),
            ("pole", (192, 192, 128)),
            ("road_marking", (255, 69, 0)),
            ("road", (128, 64, 128)),
            ("pavement", (60, 40, 222)),
            ("tree", (128, 128, 0)),
            ("sign_symbol", (192, 128, 128)),
            ("fence", (64, 64, 128)),
            ("car", (64, 0, 128)),
            ("pedestrian", (64, 64, 0)),
            ("bicyclist", (0, 128, 192)),
            ("unlabeled", (0, 0, 0)),
        ]
    )

    # Canonical mode name -> (image folder, annotation folder).
    # Consolidates the mode dispatch that was previously triplicated
    # across __init__, __getitem__, and __len__.
    _MODE_FOLDERS = {
        "train": (train_folder, train_lbl_folder),
        "val": (val_folder, val_lbl_folder),
        "test": (test_folder, test_lbl_folder),
    }

    def __init__(self, root, image_set="train", transforms=None, loader=data_utils.pil_loader):
        super().__init__()
        self.root_dir = root
        self.mode = image_set
        self.transforms = transforms
        self.loader = loader

        mode = self._canonical_mode()
        data_folder, lbl_folder = self._MODE_FOLDERS[mode]

        # Get the data and labels filepaths for the selected mode
        data_files = data_utils.get_files(
            os.path.join(self.root_dir, data_folder), extension_filter=self.img_extension
        )
        label_files = data_utils.get_files(
            os.path.join(self.root_dir, lbl_folder), extension_filter=self.img_extension
        )

        # Keep the historical mode-specific attribute names (train_data,
        # val_labels, ...) so external code that reads them keeps working.
        setattr(self, f"{mode}_data", data_files)
        setattr(self, f"{mode}_labels", label_files)

    def _canonical_mode(self):
        """Return ``self.mode`` lower-cased, raising for unsupported modes."""
        mode = self.mode.lower()
        if mode not in self._MODE_FOLDERS:
            raise nncf.ValidationError("Unexpected dataset mode. Supported modes are: train, val and test")
        return mode

    def _mode_files(self):
        """Return the (data paths, label paths) lists for the current mode."""
        mode = self._canonical_mode()
        return getattr(self, f"{mode}_data"), getattr(self, f"{mode}_labels")

    def __getitem__(self, index):
        """
        Args:
        - index (``int``): index of the item in the dataset

        Returns:
        A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
        of the image.

        """
        data_paths, label_paths = self._mode_files()
        img, label = self.loader(data_paths[index], label_paths[index])

        if self.transforms is not None:
            img, label = self.transforms(img, label)

        return img, label

    def __len__(self):
        """Returns the length of the dataset."""
        return len(self._mode_files()[0])
